{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "%matplotlib inline"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "\n# Spatial Transformer Networks Tutorial\n**Author**: [Ghassen HAMROUNI](https://github.com/GHamrouni)\n\n.. figure:: /_static/img/stn/FSeq.png\n\nIn this tutorial, you will learn how to augment your network using\na visual attention mechanism called spatial transformer\nnetworks. You can read more about spatial transformer\nnetworks in the [DeepMind paper](https://arxiv.org/abs/1506.02025).\n\nSpatial transformer networks are a generalization of differentiable\nattention to any spatial transformation. Spatial transformer networks\n(STN for short) allow a neural network to learn how to perform spatial\ntransformations on the input image in order to enhance the geometric\ninvariance of the model.\nFor example, an STN can crop a region of interest, scale an image, and correct\nits orientation. This is useful because CNNs are not inherently invariant\nto rotation, scale, or more general affine transformations.\n\nOne of the best things about STN is the ability to simply plug it into\nany existing CNN with very little modification.\n"
      ]
    },
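    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the crop, scale and rotation examples above concrete, here is a minimal NumPy sketch (not part of the original tutorial) of how a 2×3 affine matrix `theta` acts on coordinates. It assumes the convention used by `F.affine_grid` later in this notebook: coordinates are normalized to [-1, 1], and each output location (x, y) is filled from the input location `theta @ [x, y, 1]`. The helper `source_coords` and the example matrices are hypothetical.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "def source_coords(theta, x, y):\n",
        "    # Hypothetical helper: the input location sampled for output location (x, y),\n",
        "    # both expressed in normalized [-1, 1] coordinates\n",
        "    return theta @ np.array([x, y, 1.0])\n",
        "\n",
        "identity = np.array([[1.0, 0.0, 0.0],\n",
        "                     [0.0, 1.0, 0.0]])\n",
        "# Scale by 0.5: the whole output is filled from the central half of the input\n",
        "zoom = np.array([[0.5, 0.0, 0.0],\n",
        "                 [0.0, 0.5, 0.0]])\n",
        "# Rotate the sampling grid by angle a (radians) around the image center\n",
        "a = np.pi / 2\n",
        "rotate = np.array([[np.cos(a), -np.sin(a), 0.0],\n",
        "                   [np.sin(a), np.cos(a), 0.0]])\n",
        "\n",
        "print(source_coords(identity, 1.0, 1.0))  # the corner maps to itself\n",
        "print(source_coords(zoom, 1.0, 1.0))      # the corner samples from (0.5, 0.5)\n",
        "print(np.round(source_coords(rotate, 1.0, 0.0), 6))  # (1, 0) samples from (0, 1)\n",
        "```\n",
        "\n",
        "Because the matrix maps *output* coordinates to *input* coordinates, a scale factor below 1 magnifies (crops) the image rather than shrinking it."
      ]
    },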
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "# License: BSD\n# Author: Ghassen Hamrouni\n\nimport torch\nimport torch.nn as nn\nimport torch.nn.functional as F\nimport torch.optim as optim\nimport torchvision\nfrom torchvision import datasets, transforms\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nplt.ion()   # interactive mode"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Loading the data\n\nIn this post, we experiment with the classic MNIST dataset, using a\nstandard convolutional network augmented with a spatial transformer\nnetwork.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "import urllib.request\n\n# Install a global URL opener that sends a browser-like User-agent header,\n# since some dataset hosts reject the default Python one\nopener = urllib.request.build_opener()\nopener.addheaders = [('User-agent', 'Mozilla/5.0')]\nurllib.request.install_opener(opener)\n\ndevice = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n\n# Training dataset (normalized with the MNIST mean and standard deviation)\ntrain_loader = torch.utils.data.DataLoader(\n    datasets.MNIST(root='.', train=True, download=True,\n                   transform=transforms.Compose([\n                       transforms.ToTensor(),\n                       transforms.Normalize((0.1307,), (0.3081,))\n                   ])), batch_size=64, shuffle=True, num_workers=4)\n# Test dataset\ntest_loader = torch.utils.data.DataLoader(\n    datasets.MNIST(root='.', train=False, download=True,\n                   transform=transforms.Compose([\n                       transforms.ToTensor(),\n                       transforms.Normalize((0.1307,), (0.3081,))\n                   ])), batch_size=64, shuffle=True, num_workers=4)"
      ]
    },
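    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`transforms.Normalize((0.1307,), (0.3081,))` computes `(x - mean) / std` per channel, where 0.1307 and 0.3081 are the mean and standard deviation commonly used for MNIST pixels. The plain-Python sketch below (the constant and helper names are hypothetical) shows which range the [0, 1] pixel values end up in, and the inverse mapping `z * std + mean` needed to display normalized images.\n",
        "\n",
        "```python\n",
        "MNIST_MEAN = 0.1307\n",
        "MNIST_STD = 0.3081\n",
        "\n",
        "def normalize(x):\n",
        "    # What transforms.Normalize does to a single pixel value\n",
        "    return (x - MNIST_MEAN) / MNIST_STD\n",
        "\n",
        "def denormalize(z):\n",
        "    # Inverse mapping, used when turning tensors back into viewable images\n",
        "    return z * MNIST_STD + MNIST_MEAN\n",
        "\n",
        "print(round(normalize(0.0), 4))  # -0.4242 (black pixel)\n",
        "print(round(normalize(1.0), 4))  # 2.8215 (white pixel)\n",
        "```\n",
        "\n",
        "Normalized values therefore lie roughly in [-0.42, 2.82], so any code that displays these tensors should apply the inverse mapping with the same statistics before clipping to [0, 1]."
      ]
    },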
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Depicting spatial transformer networks\n\nA spatial transformer network boils down to three main components:\n\n-  The localization network is a regular CNN which regresses the\n   transformation parameters. The transformation is never supervised\n   explicitly; instead, the network automatically learns the spatial\n   transformations that improve the overall classification accuracy.\n-  The grid generator generates a grid of coordinates in the input\n   image corresponding to each pixel of the output image.\n-  The sampler samples the input image at those grid coordinates\n   to produce the transformed output image.\n\n.. figure:: /_static/img/stn/stn-arch.png\n\n.. Note::\n   We need a version of PyTorch that provides the affine_grid and\n   grid_sample functions in torch.nn.functional.\n\n\n"
      ]
    },
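    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "In PyTorch, the grid generator and the sampler correspond to `F.affine_grid` and `F.grid_sample`. Below is a minimal standalone sketch on a random 1×1×28×28 tensor (a stand-in for one MNIST image): with the identity `theta`, which is also how the model below is initialized, the output reproduces the input, while a scaled `theta` zooms into the center. We pass `align_corners=False` explicitly to match the current default; the helper `spatial_transform` is hypothetical.\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "def spatial_transform(x, theta):\n",
        "    # Grid generator: for every output pixel, a normalized (x, y) location in the input\n",
        "    grid = F.affine_grid(theta, x.size(), align_corners=False)\n",
        "    # Sampler: bilinear interpolation of the input at those locations\n",
        "    return F.grid_sample(x, grid, align_corners=False), grid\n",
        "\n",
        "x = torch.rand(1, 1, 28, 28)  # (N, C, H, W)\n",
        "\n",
        "identity = torch.tensor([[[1.0, 0.0, 0.0],\n",
        "                          [0.0, 1.0, 0.0]]])  # shape (N, 2, 3)\n",
        "zoom = torch.tensor([[[0.5, 0.0, 0.0],\n",
        "                      [0.0, 0.5, 0.0]]])\n",
        "\n",
        "out_id, grid = spatial_transform(x, identity)\n",
        "out_zoom, _ = spatial_transform(x, zoom)\n",
        "\n",
        "print(grid.shape)  # torch.Size([1, 28, 28, 2]): one (x, y) pair per output pixel\n",
        "print(torch.allclose(out_id, x, atol=1e-5))  # identity leaves the image unchanged\n",
        "print(out_zoom.shape)  # same size as the input, but showing only its central region\n",
        "```"
      ]
    },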
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "class Net(nn.Module):\n    def __init__(self):\n        super(Net, self).__init__()\n        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)\n        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)\n        self.conv2_drop = nn.Dropout2d()\n        self.fc1 = nn.Linear(320, 50)\n        self.fc2 = nn.Linear(50, 10)\n\n        # Spatial transformer localization-network\n        self.localization = nn.Sequential(\n            nn.Conv2d(1, 8, kernel_size=7),\n            nn.MaxPool2d(2, stride=2),\n            nn.ReLU(True),\n            nn.Conv2d(8, 10, kernel_size=5),\n            nn.MaxPool2d(2, stride=2),\n            nn.ReLU(True)\n        )\n\n        # Regressor for the 2 x 3 affine matrix\n        self.fc_loc = nn.Sequential(\n            nn.Linear(10 * 3 * 3, 32),\n            nn.ReLU(True),\n            nn.Linear(32, 3 * 2)\n        )\n\n        # Zero the last layer's weights and set its bias to the identity\n        # transformation, so the STN starts out as a no-op\n        self.fc_loc[2].weight.data.zero_()\n        self.fc_loc[2].bias.data.copy_(torch.tensor([1, 0, 0, 0, 1, 0], dtype=torch.float))\n\n    # Spatial transformer network forward function\n    def stn(self, x):\n        # Localization network: regress the affine parameters theta\n        xs = self.localization(x)\n        xs = xs.view(-1, 10 * 3 * 3)\n        theta = self.fc_loc(xs)\n        theta = theta.view(-1, 2, 3)\n\n        # Grid generator: input coordinates for every output pixel\n        grid = F.affine_grid(theta, x.size(), align_corners=False)\n        # Sampler: bilinearly sample the input at those coordinates\n        x = F.grid_sample(x, grid, align_corners=False)\n\n        return x\n\n    def forward(self, x):\n        # transform the input\n        x = self.stn(x)\n\n        # Perform the usual forward pass\n        x = F.relu(F.max_pool2d(self.conv1(x), 2))\n        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))\n        x = x.view(-1, 320)\n        x = F.relu(self.fc1(x))\n        x = F.dropout(x, training=self.training)\n        x = self.fc2(x)\n        return F.log_softmax(x, dim=1)\n\n\nmodel = Net().to(device)"
      ]
    },
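    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The hard-coded sizes `10 * 3 * 3` and `320` in `Net` follow from the 28×28 MNIST input (the STN output keeps that size because `affine_grid` is given `x.size()`). Here is a small pure-Python sketch of the arithmetic, assuming unpadded stride-1 convolutions and 2×2 max pooling that rounds down, as in the layers above; `conv_out` and `pool_out` are hypothetical helpers.\n",
        "\n",
        "```python\n",
        "def conv_out(size, kernel):\n",
        "    # Unpadded, stride-1 convolution\n",
        "    return size - kernel + 1\n",
        "\n",
        "def pool_out(size, window=2):\n",
        "    # Max pooling with kernel = stride = window; the result is floored\n",
        "    return size // window\n",
        "\n",
        "# Localization network: Conv(k=7) -> Pool(2) -> Conv(k=5) -> Pool(2)\n",
        "loc = pool_out(conv_out(pool_out(conv_out(28, 7)), 5))\n",
        "# Classifier path: Conv(k=5) -> Pool(2) -> Conv(k=5) -> Pool(2)\n",
        "cls = pool_out(conv_out(pool_out(conv_out(28, 5)), 5))\n",
        "\n",
        "print(loc, 10 * loc * loc)  # 3 90 -> input size of fc_loc\n",
        "print(cls, 20 * cls * cls)  # 4 320 -> input size of fc1\n",
        "```\n",
        "\n",
        "If you change the input resolution or the kernel sizes, these two numbers (and the matching `view` calls) must be updated accordingly."
      ]
    },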
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Training the model\n\nNow, let's use the SGD algorithm to train the model. The network\nlearns the classification task in a supervised way. At the same time,\nthe model learns the STN parameters automatically in an end-to-end fashion.\n\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "optimizer = optim.SGD(model.parameters(), lr=0.01)\n\n\ndef train(epoch):\n    model.train()\n    for batch_idx, (data, target) in enumerate(train_loader):\n        data, target = data.to(device), target.to(device)\n\n        optimizer.zero_grad()\n        output = model(data)\n        loss = F.nll_loss(output, target)\n        loss.backward()\n        optimizer.step()\n        if batch_idx % 500 == 0:\n            print('Train Epoch: {} [{}/{} ({:.0f}%)]\\tLoss: {:.6f}'.format(\n                epoch, batch_idx * len(data), len(train_loader.dataset),\n                100. * batch_idx / len(train_loader), loss.item()))\n#\n# A simple test procedure to measure the STN performance on MNIST.\n#\n\n\ndef test():\n    with torch.no_grad():\n        model.eval()\n        test_loss = 0\n        correct = 0\n        for data, target in test_loader:\n            data, target = data.to(device), target.to(device)\n            output = model(data)\n\n            # sum up batch loss (size_average=False is deprecated)\n            test_loss += F.nll_loss(output, target, reduction='sum').item()\n            # get the index of the max log-probability\n            pred = output.argmax(dim=1, keepdim=True)\n            correct += pred.eq(target.view_as(pred)).sum().item()\n\n        test_loss /= len(test_loader.dataset)\n        print('\\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\\n'\n              .format(test_loss, correct, len(test_loader.dataset),\n                      100. * correct / len(test_loader.dataset)))"
      ]
    },
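    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Nothing in the training loop above supervises `theta` directly: the localization network is trained only because `affine_grid` and `grid_sample` are differentiable with respect to `theta`, so the classification loss backpropagates into it. A minimal standalone sketch of that property, using a random batch and a stand-in loss (both assumptions, not the tutorial's model or loss):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torch.nn.functional as F\n",
        "\n",
        "torch.manual_seed(0)\n",
        "x = torch.rand(4, 1, 28, 28)  # stand-in batch of images\n",
        "\n",
        "# Start from the identity transform, as the fc_loc bias in Net does\n",
        "theta = torch.tensor([[1.0, 0.0, 0.0],\n",
        "                      [0.0, 1.0, 0.0]]).repeat(4, 1, 1).requires_grad_()\n",
        "\n",
        "grid = F.affine_grid(theta, x.size(), align_corners=False)\n",
        "out = F.grid_sample(x, grid, align_corners=False)\n",
        "\n",
        "# Stand-in loss; in the tutorial, F.nll_loss on the classifier output plays this role\n",
        "loss = out.mean()\n",
        "loss.backward()\n",
        "\n",
        "print(theta.grad.shape)  # torch.Size([4, 2, 3]): one gradient per affine parameter\n",
        "```"
      ]
    },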
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Visualizing the STN results\n\nNow, we will inspect the results of our learned visual attention\nmechanism.\n\nWe define a small helper function to visualize the\ntransformations after training.\n\n"
      ]
    },
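    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The helper below tiles a batch into one image with `torchvision.utils.make_grid`, which returns a CHW tensor (repeating a single channel to three), and then transposes it to the HWC layout that `plt.imshow` expects. A sketch of the shapes involved for a random batch of 64 MNIST-sized tensors, assuming `make_grid`'s defaults (`nrow=8`, `padding=2`):\n",
        "\n",
        "```python\n",
        "import torch\n",
        "import torchvision\n",
        "\n",
        "batch = torch.rand(64, 1, 28, 28)  # stand-in for one test batch\n",
        "\n",
        "grid = torchvision.utils.make_grid(batch)  # defaults: nrow=8, padding=2\n",
        "image = grid.numpy().transpose((1, 2, 0))  # CHW -> HWC for matplotlib\n",
        "\n",
        "print(tuple(grid.shape))  # (3, 242, 242): 8 x 8 tiles of 28 px plus 2 px padding\n",
        "print(image.shape)        # (242, 242, 3)\n",
        "```"
      ]
    },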
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "collapsed": false
      },
      "outputs": [],
      "source": [
        "def convert_image_np(inp):\n    \"\"\"Convert a normalized Tensor to a numpy image.\"\"\"\n    inp = inp.numpy().transpose((1, 2, 0))\n    # Undo transforms.Normalize with the MNIST statistics used by the data loaders\n    mean = np.array([0.1307, 0.1307, 0.1307])\n    std = np.array([0.3081, 0.3081, 0.3081])\n    inp = std * inp + mean\n    inp = np.clip(inp, 0, 1)\n    return inp\n\n# We want to visualize the output of the spatial transformer layer\n# after training: we visualize a batch of input images and\n# the corresponding batch transformed by the STN.\n\n\ndef visualize_stn():\n    with torch.no_grad():\n        # Get a batch of test data\n        data = next(iter(test_loader))[0].to(device)\n\n        input_tensor = data.cpu()\n        transformed_input_tensor = model.stn(data).cpu()\n\n        in_grid = convert_image_np(\n            torchvision.utils.make_grid(input_tensor))\n\n        out_grid = convert_image_np(\n            torchvision.utils.make_grid(transformed_input_tensor))\n\n        # Plot the results side-by-side\n        f, axarr = plt.subplots(1, 2)\n        axarr[0].imshow(in_grid)\n        axarr[0].set_title('Dataset Images')\n\n        axarr[1].imshow(out_grid)\n        axarr[1].set_title('Transformed Images')\n\nfor epoch in range(1, 20 + 1):\n    train(epoch)\n    test()\n\n# Visualize the STN transformation on some input batch\nvisualize_stn()\n\nplt.ioff()\nplt.show()"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.4"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}